package depth_first_search;

import java.util.*;

/**
 * Created by gouthamvidyapradhan on 20/08/2019
 *
 * <p>In a network of nodes, each node {@code i} is directly connected to another node {@code j}
 * if and only if {@code graph[i][j] = 1}.
 *
 * <p>Some nodes, listed in {@code initial}, are initially infected by malware. Whenever two nodes
 * are directly connected and at least one of those two nodes is infected by malware, both nodes
 * will be infected by malware. This spread of malware will continue until no more nodes can be
 * infected in this manner.
 *
 * <p>Suppose {@code M(initial)} is the final number of nodes infected with malware in the entire
 * network, after the spread of malware stops.
 *
 * <p>We will remove one node from the initial list. Return the node that, if removed, would
 * minimize {@code M(initial)}. If multiple nodes could be removed to minimize {@code M(initial)},
 * return such a node with the smallest index.
 *
 * <p>Note that if a node was removed from the initial list of infected nodes, it may still be
 * infected later as a result of the malware spread.
 *
 * <pre>
 * Example 1:
 * Input: graph = [[1,1,0],[1,1,0],[0,0,1]], initial = [0,1]
 * Output: 0
 *
 * Example 2:
 * Input: graph = [[1,0,0],[0,1,0],[0,0,1]], initial = [0,2]
 * Output: 0
 *
 * Example 3:
 * Input: graph = [[1,1,1],[1,1,1],[1,1,1]], initial = [1,2]
 * Output: 1
 * </pre>
 *
 * <p>Note:
 *
 * <ul>
 *   <li>{@code 1 < graph.length = graph[0].length <= 300}
 *   <li>{@code 0 <= graph[i][j] == graph[j][i] <= 1}
 *   <li>{@code graph[i][i] = 1}
 *   <li>{@code 1 <= initial.length < graph.length}
 *   <li>{@code 0 <= initial[i] < graph.length}
 * </ul>
 *
 * <p>Solution: O(N ^ 2 + I) time, where N is the number of nodes and I is the size of {@code
 * initial}. Build an adjacency list from the N x N matrix, then run a DFS from each initial node
 * that has not been visited yet. The DFS colors every reachable node with the starting node's
 * index and records how many nodes it reached (the component size); already visited nodes are
 * never re-visited. An initial node is a candidate only if no other initial node shares its
 * color. Removing such a node saves its whole component, whereas removing one of several infected
 * nodes in the same component saves nothing. Return the candidate with the largest component,
 * breaking ties by smallest index. If there is no candidate, return the smallest initial node.
 */
public class MinimizeMalwareSpread {
  public static void main(String[] args) {
    int[][] graph = {
      {1, 0, 0, 0, 0, 0},
      {0, 1, 0, 0, 0, 0},
      {0, 0, 1, 0, 0, 0},
      {0, 0, 0, 1, 1, 0},
      {0, 0, 0, 1, 1, 0},
      {0, 0, 0, 0, 0, 1}
    };
    int[] i = {5, 0};
    System.out.println(new MinimizeMalwareSpread().minMalwareSpread(graph, i));
  }

  /** Undirected adjacency list built from the adjacency matrix (node -> neighbours). */
  Map<Integer, List<Integer>> graphMap;
  /** Component size, keyed by the initial node whose DFS discovered that component. */
  Map<Integer, Integer> size;
  /** Nodes already reached by some DFS. */
  Set<Integer> done;
  /** Color of each visited node: the initial node whose DFS reached it first. */
  Map<Integer, Integer> color;
  /** Number of nodes reached by the DFS currently in progress. */
  int count = 0;

  /**
   * Returns the initial node whose removal minimizes the final number of infected nodes.
   *
   * @param graph 0/1 adjacency matrix of the network; an edge in either direction is treated as
   *     undirected
   * @param initial initially infected nodes; this array is not modified
   * @return the node to remove; the smallest such node when several are equally good
   * @throws IllegalArgumentException if {@code initial} is empty
   */
  public int minMalwareSpread(int[][] graph, int[] initial) {
    if (initial.length == 0) {
      throw new IllegalArgumentException("initial must contain at least one node");
    }
    graphMap = new HashMap<>();
    done = new HashSet<>();
    color = new HashMap<>();
    size = new HashMap<>();
    buildGraph(graph);

    for (int node : initial) {
      if (!done.contains(node)) {
        count = 0;
        dfs(node, node);
        size.put(node, count);
      }
    }

    // Number of initially infected nodes in each component (identified by its color).
    Map<Integer, Integer> infectedPerColor = new HashMap<>();
    for (int node : initial) {
      infectedPerColor.merge(color.get(node), 1, Integer::sum);
    }

    int smallestInitial = Integer.MAX_VALUE;
    int answer = -1;
    int max = 0;
    for (int node : initial) {
      smallestInitial = Math.min(smallestInitial, node);
      // Removing a node only reduces M(initial) if it is the sole infected node of its
      // component; such a node started its own DFS, so its component size is in `size`.
      if (infectedPerColor.get(color.get(node)) != 1) {
        continue;
      }
      int saved = size.get(node);
      if (saved > max || (saved == max && node < answer)) {
        max = saved;
        answer = node;
      }
    }
    // No node is alone in its component: every removal saves nothing, so pick the smallest.
    return answer == -1 ? smallestInitial : answer;
  }

  /** Fills {@link #graphMap} with one undirected entry per edge, skipping self-loops. */
  private void buildGraph(int[][] graph) {
    int n = graph.length;
    for (int i = 0; i < n; i++) {
      for (int j = i + 1; j < n; j++) {
        if (graph[i][j] == 1 || graph[j][i] == 1) {
          graphMap.computeIfAbsent(i, k -> new ArrayList<>()).add(j);
          graphMap.computeIfAbsent(j, k -> new ArrayList<>()).add(i);
        }
      }
    }
  }

  /**
   * Colors every not-yet-visited node reachable from {@code i} with {@code col}, marks it done
   * and adds it to {@link #count}. Uses an explicit stack so large components cannot overflow the
   * call stack.
   */
  private void dfs(int i, int col) {
    Deque<Integer> stack = new ArrayDeque<>();
    stack.push(i);
    done.add(i);
    while (!stack.isEmpty()) {
      int node = stack.pop();
      color.put(node, col);
      count++;
      for (int next : graphMap.getOrDefault(node, Collections.<Integer>emptyList())) {
        // Mark on push so every node is counted exactly once.
        if (done.add(next)) {
          stack.push(next);
        }
      }
    }
  }
}
